package idlrpc

import (
	"context"
	"errors"
	"fmt"
	"gitee.com/CloudGuan/rpc-go-backend/idlrpc/common"
	"gitee.com/CloudGuan/rpc-go-backend/idlrpc/protocol"
	"gitee.com/CloudGuan/rpc-go-backend/idlrpc/stub"
	"gitee.com/CloudGuan/rpc-go-backend/idlrpc/stubcall"
	"gitee.com/CloudGuan/rpc-go-backend/idlrpc/transport"
	"google.golang.org/protobuf/proto"
	"sync"
	"time"
)

var (
	// srvCallMgr is the server-side call stub manager ("remote call stubs"). It is
	// handed to every server stub built in registerStub and ticked by Tick.
	srvCallMgr *stubcall.SrvStubCallMgr

	// clientCallMgr holds the outstanding client stub calls; responses are matched
	// to them by call ID.
	clientCallMgr *stubcall.ClientStubCallMgr

	// srvStubMgr is the registry of locally served service stubs.
	srvStubMgr *stub.StubManager

	// clientStubMgr is the registry of client proxies, looked up by stub ID.
	clientStubMgr *stub.ClientStubMgr

	// once guarantees the managers above are constructed a single time.
	once sync.Once
)

// init constructs the package-level RPC managers when the package is loaded.
// Logging and error-library initialisation are meant to happen here too but are
// not implemented yet.
func init() {
	once.Do(initRpc)
}

func initRpc() {
	srvCallMgr = stubcall.NewSrvStubCallMgr()
	clientCallMgr = stubcall.NewClientStubCallMgr()
	srvStubMgr = stub.NewServerStubMgr()
	clientStubMgr = stub.NewClientStubMgr()
}

// GetServerStubMgr returns the package-wide registry of served service stubs.
func GetServerStubMgr() *stub.StubManager {
	mgr := srvStubMgr
	return mgr
}

// GetClientStubMgr returns the package-wide registry of client proxies.
func GetClientStubMgr() *stub.ClientStubMgr {
	mgr := clientStubMgr
	return mgr
}

// GetStubCallMgr returns the server-side call stub manager (not the client one).
func GetStubCallMgr() *stubcall.SrvStubCallMgr {
	mgr := srvCallMgr
	return mgr
}

// Start starts the server stub manager.
func Start() {
	mgr := srvStubMgr
	mgr.Start()
}

// OnMessage processes every complete message currently buffered in trans.
// Requests are dispatched to onCall and responses to onReturn. Non-RPC messages
// are discarded.
//
// It stops (returns) as soon as any of these happens:
//   - the next header cannot be peeked or parsed;
//   - the next message has not fully arrived;
//   - the message type is unknown;
//   - a handler reports an error.
//
// A header whose Length is shorter than the 8-byte base header is malformed; the
// transport is closed in that case (this includes Length == 0).
func OnMessage(trans transport.Transport, ctx context.Context) {
	// Each message starts with an 8-byte base header. Peek does not consume it, and
	// its Length is compared with the buffered byte count, i.e. it counts the whole
	// message including headers.
	const baseHeaderSize = 8

	// Process as many buffered messages as possible in one pass.
	for {
		raw, n, err := trans.Peek(baseHeaderSize)
		if err != nil || n != baseHeaderSize {
			return
		}

		header := protocol.ReadHeader(raw)
		if header == nil {
			return
		}

		// A message cannot be shorter than its own base header. Letting such a
		// length through would make the handlers compute a negative body size.
		if header.Length < baseHeaderSize {
			trans.Close()
			return
		}

		// The whole message has not arrived yet; wait for the network side to
		// buffer more data.
		if header.Length > trans.Size() {
			return
		}

		switch header.Type {
		case protocol.REQUEST_MSG:
			if onCall(trans, ctx) != nil {
				// TODO: log, and decide whether the transport should be closed.
				return
			}
		case protocol.RESPONESE_MSG:
			if onReturn(trans, ctx) != nil {
				// TODO: log, and decide whether the transport should be closed.
				return
			}
		case protocol.NOTRPC_MSG:
			// Consume the message. Otherwise the next iteration would peek the
			// very same header and the loop would spin forever.
			size := int(header.Length)
			if read, err := trans.Read(make([]byte, size), size); err != nil || read != size {
				return
			}
		default:
			return
		}
	}
}

// onCall consumes one request message from trans and hands it to the target
// server stub.
//
// The message is read as a fixed-size call header followed by a body of
// Length-CallHeadSize bytes. The body is consumed even when the requested service
// is unknown, so that the next message in the stream starts on a header boundary.
// For an unknown service a not-found response is sent back on trans and an error
// is returned.
func onCall(trans transport.Transport, ctx context.Context) error {
	head := make([]byte, protocol.CallHeadSize)
	if n, err := trans.Read(head, protocol.CallHeadSize); n != protocol.CallHeadSize || err != nil {
		return errors.New("Read Call Header Error!!!!!!!")
	}

	reqheader := protocol.ReadCallHeader(head)
	if reqheader == nil {
		return errors.New("Read Req Header error")
	}

	// A declared length shorter than the call header would give a negative body
	// size, and make() would panic.
	msglen := int(reqheader.Length) - protocol.CallHeadSize
	if msglen < 0 {
		return errors.New("invalid request length")
	}

	target := srvStubMgr.GetService(reqheader.ServiceUUID, reqheader.ServerID)

	// Always drain the body, even for an unknown service; leaving it in the
	// transport would make the next header be parsed from payload bytes.
	body := make([]byte, msglen)
	reallen, err := trans.Read(body, msglen)

	if target == nil {
		notFound(trans, reqheader)
		return errors.New("NOT FOUND SERVICE")
	}

	if err != nil || reallen != msglen {
		return errors.New("read message body error ")
	}

	reqpkg := &protocol.RequestPackage{
		reqheader,
		body,
	}
	target.CallStub(trans, reqpkg)
	return nil
}

// onReturn consumes one response message from trans and delivers it to the
// client stub call waiting on it (matched by CallID).
//
// The body is read before any lookup. A response whose call no longer exists
// (or whose client stub is gone) is still fully removed from the stream, so it
// cannot corrupt the next header.
//
// Side effects, depending on the response error code:
//   - IDL_SUCCESS: the client stub's server ID is updated from the response.
//   - IDL_SERVICE_NOT_FOUND, IDL_SERVICE_ERROR, IDL_RPC_TIME_OUT: the server ID
//     is reset to INVALIED_STUB_ID.
//
// A short body read marks the call as IDL_SERVICE_ERROR. The response is handed
// to the call on a new goroutine.
func onReturn(trans transport.Transport, ctx context.Context) error {
	head := make([]byte, protocol.RespHeadSize)
	if n, err := trans.Read(head, protocol.RespHeadSize); n != protocol.RespHeadSize || err != nil {
		return errors.New("Read Resp Header Error!!!!!!!")
	}

	respheader := protocol.ReadRetHeader(head)
	if respheader == nil {
		return errors.New("Read Resp Header error")
	}

	// A declared length shorter than the response header would give a negative
	// body size, and make() would panic.
	bodylen := int(respheader.Length) - protocol.RespHeadSize
	if bodylen < 0 {
		return errors.New("invalid response length")
	}

	body := make([]byte, bodylen)
	reslen, err := trans.Read(body, bodylen)
	bodyOK := err == nil && reslen == bodylen

	call := clientCallMgr.GetClientStubCall(respheader.CallID)
	if call == nil {
		return errors.New("Get SrvStub Cache Error ！！！！")
	}

	call.SetErrorCode(respheader.ErrorCode)

	cstub := clientStubMgr.GetClientStub(call.GetStubID())
	if cstub == nil {
		clientCallMgr.Destory(respheader.CallID)
		return errors.New("Get SrvStub Cache Error ！！！！")
	}

	switch respheader.ErrorCode {
	case protocol.IDL_SUCCESS:
		cstub.SetSrvID(respheader.ServerID)
	case protocol.IDL_SERVICE_NOT_FOUND, protocol.IDL_SERVICE_ERROR, protocol.IDL_RPC_TIME_OUT:
		cstub.SetSrvID(common.INVALIED_STUB_ID)
	}

	if !bodyOK {
		call.SetErrorCode(protocol.IDL_SERVICE_ERROR)
	}

	// Hand the response back to the goroutine waiting on this call.
	go call.Ret(&protocol.ResponsePackage{Header: respheader, Buffer: body})
	return nil
}

// registerStub wraps srvstub in a server stub bound to the shared srvCallMgr and
// adds it to srvStubMgr. If the wrapper cannot be created, nothing is registered.
//
// This is the callee side: it does not deal with the transport, only with the
// call manager.
func registerStub(srvstub stub.SrvStub) {
	if wrapped := stub.NewServerStub(srvstub, srvCallMgr); wrapped != nil {
		srvStubMgr.AddService(wrapped)
	}
	// TODO: log when NewServerStub returns nil.
}

// RegisterService registers a service stub and attaches its implementation.
//
// uuid must be non-zero and equal to srvstub.GetUUID(). The stub is wrapped and
// added to the server stub manager. The registered service is then looked up
// again (with server ID 0) and impl is attached to it.
//
// It panics if the server stub manager has not been initialised. It returns an
// error when:
//   - uuid is zero;
//   - srvstub is nil;
//   - uuid does not match srvstub.GetUUID();
//   - the service cannot be found after registration.
func RegisterService(uuid uint64, srvstub stub.SrvStub, impl interface{}) error {
	if srvStubMgr == nil {
		panic("Register service error, stub manager not init yet !!!!")
	}

	if uuid == 0 {
		return errors.New("Invalid uuid!")
	}

	// Calling GetUUID on a nil stub would panic; report it as an error instead.
	if srvstub == nil {
		return errors.New("service stub is nil")
	}

	if uuid != srvstub.GetUUID() {
		return errors.New("implement not match with service stub ！！！！")
	}
	registerStub(srvstub)

	// Fetch the freshly registered service to bind the implementation to it.
	stubptr := srvStubMgr.GetService(uuid, 0)
	if stubptr == nil {
		return errors.New("[Rpc] Unkown service type")
	}

	stubptr.RegisterService(impl)
	return nil
}

// Call performs a blocking RPC through the client proxy registered under stubid.
//
// A client stub call carrying retry and timeout is registered with the client
// call manager for the duration of Call and destroyed on return.
//
// Behaviour:
//   - Oneway methods return (nil, nil) as soon as the request is sent.
//   - Otherwise the request is retried while retries remain and the response
//     reports IDL_RPC_TIME_OUT.
//   - Errors from sending or retrying are returned as is.
//   - A nil response is reported as an error.
//   - A non-success error code in the final response becomes an error; the
//     response is still returned alongside it.
func Call(stubid, methodid, retry, timeout uint32, message proto.Message) (resp *protocol.ResponsePackage, err error) {
	if clientStubMgr == nil {
		return nil, errors.New("Service not Found!!")
	}

	if stubid == 0 {
		return nil, errors.New("Service Id is invalid!!!")
	}

	cstub := clientStubMgr.GetClientStub(stubid)
	if cstub == nil {
		return nil, errors.New("Proxy not found!!")
	}

	// Named "call" rather than "stubcall" so the stubcall package is not shadowed.
	call := stubcall.NewClientStubCall(stubid, retry, timeout)
	if call == nil {
		return nil, errors.New("Create StubCall error")
	}

	clientCallMgr.Add(call)
	defer func() {
		// Read the call ID at return time, after any retries.
		clientCallMgr.Destory(call.GetCallID())
	}()

	startT := time.Now()
	resp, err = cstub.CallMetchod(methodid, call, message)
	if err != nil {
		// Send failures are returned directly; errors reported by the peer are
		// translated below.
		return resp, err
	}

	// Oneway methods have no response to wait for.
	if cstub.IsOneway(methodid) {
		// TODO: detect network errors and handle exceptions for oneway calls.
		return nil, nil
	}

	// Only timeouts are retried; any other error code is reported to the caller.
	for {
		if resp == nil {
			return nil, errors.New("[idlrpc] empty response")
		}
		if call.GetRetryTime() <= 0 || resp.Header.ErrorCode != protocol.IDL_RPC_TIME_OUT {
			break
		}
		call.DecrRetryTime()
		if resp, err = cstub.Retry(call); err != nil {
			return resp, err
		}
	}

	// TODO: use a dedicated error library mirroring the server-side errors.
	switch resp.Header.ErrorCode {
	case protocol.IDL_SUCCESS:
	case protocol.IDL_SERVICE_NOT_FOUND:
		err = fmt.Errorf("[idlrpc] %d,%s,%d  service not found!", cstub.GetSrvUUID(), cstub.GetSrvName(), methodid)
	case protocol.IDL_SERVICE_ERROR:
		err = errors.New("Idl service error!")
	case protocol.IDL_RPC_TIME_OUT:
		err = fmt.Errorf("[idlrpc] %d,%s,%d  service call timeout use time: %v!", cstub.GetSrvUUID(), cstub.GetSrvName(), methodid, time.Since(startT))
	default:
	}

	return resp, err
}

// RegisterProxy wraps proxy in a client stub and registers it with the client
// stub manager, which assigns the stub ID. That ID is then written back into
// the proxy.
//
// TODO: the service address should eventually come from zookeeper, with a
// connection established on first use when none exists yet.
func RegisterProxy(proxy stub.ProxyStub) error {
	client := stub.NewClientStub(proxy)
	if client == nil {
		return errors.New("add stub error !!!")
	}

	clientStubMgr.AddClientStub(client)

	// The ID is assigned by AddClientStub, so it is only read afterwards.
	assignedID := client.GetStubID()
	proxy.SetClientID(assignedID)
	return nil
}

// Tick runs one round of periodic housekeeping. The server stub manager is
// ticked first, then the server-side call stub manager.
func Tick() {
	stubs, calls := srvStubMgr, srvCallMgr
	stubs.Tick()
	calls.Tick()
}

// notFound sends a "service not found" response for req back on trans.
// Nothing is sent if the response cannot be built or packed.
func notFound(trans transport.Transport, req *protocol.RpcCallHeader) {
	reply := protocol.BuildNotFount(req)
	if reply == nil {
		return
	}

	// TODO: proper serialisation, and report packing failures.
	if packed, size := protocol.PackRespMsg(reply); packed != nil && size != 0 {
		trans.Send(packed)
	}
}
